/*********************************************************************
Replication code for Systemic Discrimination Among Large U.S. Employers
Patrick M. Kline, Evan K. Rose, Christopher R. Walters
April, 2022

This code produces Table A4, which reports lower bounds on the prevalence
of job-level discrimination.

It requires the top-level directory to be set to the replication folder using
the global below.
*********************************************************************/

* Root of the replication package; edit this path before running.
global dir "/accounts/projects/pkline/randres/randres/replication"

* Start from a clean session. -capture- keeps -restore- from stopping the
* do-file when there is no preserved copy of the data to restore.
capture restore
clear all

* Set the random-number seed (the bootstrap weights drawn later are random).
set seed 126312

* File for collecting results: an empty dataset to which one row per
* attribute is appended. The temporary path is quoted so that it still
* parses when the temporary directory contains spaces.
tempfile allresults
clear
save "`allresults'", replace emptyok

* Number of BS reps
global bsreps = 500

* Other options (not referenced anywhere else in the visible code)
local balanced 0
local trim 0

* Mata code to get ustat
*
* ustater() takes no arguments. It reads three Stata matrices:
*   fsplitmeans  one row per cell; column 1 is the unit id (rows of the same
*                unit must be adjacent), column 3 is the cell value
*   jweights     column vector with one weight per row of fsplitmeans
*   fweights     column vector with one weight per unit, in row order
* and stores a 1x1 Stata matrix named btwnvar.
cap mata: mata drop ustater()
mata
void ustater() {
    // Inputs prepared by the caller
    cellmat = st_matrix("fsplitmeans")
    cellw   = st_matrix("jweights")
    jobw    = st_matrix("fweights")

    njob  = length(jobw)
    wtot  = sum(jobw)
    ncell = length(cellmat[.,1])

    // ind[r,j] = 1 when row r belongs to the j-th unit. A new unit starts
    // whenever the id in column 1 differs from the id on the previous row.
    ind = J(ncell,njob,0)
    j = 1
    for (r=1; r<=ncell; r++) {
        if (r > 1) {
            if (cellmat[r,1] != cellmat[r-1,1]) {
                j++
            }
        }
        ind[r,j] = 1
    }

    // Weighted mean of column 3 within each unit, scaled by the unit weight
    wsums   = ind'*cellw
    jobmean = (ind'*(cellmat[.,3]:*cellw)):/wsums
    jobmean = jobmean:*jobw

    // Sum of products of scaled means over all pairs of distinct units
    outer   = jobmean*jobmean'
    offdiag = sum(outer - diag(outer)) / wtot^2

    // For each unit: weighted products of its distinct cells, normalised by
    // the matching products of the cell weights
    sq = J(njob,1,0)
    for (j=1; j<=njob; j++) {
        wsel = select(cellw,ind[.,j])
        ysel = select(cellmat[.,3],ind[.,j]):*wsel
        yy = ysel * ysel'
        ww = wsel * wsel'
        sq[j] = sum(yy - diag(yy)) / sum(ww - diag(ww))
    }
    sq = sq:*(jobw:*(wtot:-jobw))

    // Hand the estimate back to Stata as the matrix btwnvar
    st_matrix("btwnvar", quadsum(sq)/wtot^2 - offdiag)
}
end


* Function for split estimator of variance
*
* Syntax: withinustat attribute [split]
*   attribute  variable whose values 0 and 1 define the two groups compared
*              (the code differences cb1 - cb0 after reshaping on it)
*   split      cutoff on the variable order: rows with order >= split form
*              split group 1, all others split group 0 (default 5)
*
* Uses the variables order, firm_id, jid, w and cb from the data in memory
* and passes the matrices fsplitmeans, jweights and fweights to ustater().
* The data are preserved on entry and restored before exit.
*
* Returns in r():
*   total_btwnvar_wt, total_btwnsd_wt   ustater() result with weights 125/njobs
*   total_btwnvar,    total_btwnsd      ustater() result with unit weights
* The *sd values are square roots, so they are missing whenever the
* corresponding variance estimate is negative.
capture program drop withinustat
program define withinustat, rclass
    args attribute split

    * Default to 1-4 vs. 5-8 split
    if "`split'" == "" {
        local split = 5
    }

    preserve
    qui gen split_group = order >= `split'
    gen fid = firm_id

    * Give every row of a job-by-split cell the same weight.
    * NOTE(review): first() and nvals() are not official egen functions;
    * presumably they come from the user-written egenmore package - confirm.
    bys jid split_group: egen _tmpw = first(w)
    qui replace w = _tmpw

    * Mean of cb per job, split group and attribute value, then put the two
    * attribute values side by side and difference them. Cells lacking either
    * attribute value have a missing difference and are dropped.
    collapse (mean) cb (firstnm) fid w, by(jid split_group `attribute')
    qui reshape wide cb, i(jid fid split_group w) j(`attribute')
    qui gen dif = cb1 - cb0
    drop cb0 cb1
    qui drop if dif == .

    * Keep if multiple split groups
    bys jid: egen nsplits = nvals(split_group)
    qui keep if nsplits == 2

    * Weights to get firm-weighted estimates
    * njobs counts the distinct jobs remaining for each firm.
    * NOTE(review): the constant 125 is not explained in this file - confirm
    * what it represents.
    bys fid: egen njobs = nvals(jid)
    gen double wt = 125/njobs
    * tg marks exactly one row per job, so fweights below has one row per job
    egen tg = tag(jid) 

    * Do total variance
    * ustater() finds job boundaries by comparing the ids on consecutive rows,
    * so the rows must be grouped by jid before the matrices are built.
    sort jid split_group 
    mkmat jid split_group dif, matrix(fsplitmeans)
    * matrix jweights = J(`= rowsof(fsplitmeans)',1,1)
    *matrix fweights = J(`= rowsof(fsplitmeans)'/2,1,1)
    mkmat w, matrix(jweights)
    mkmat wt if tg == 1, matrix(fweights)

    * Get ustat estimate
    * ustater() leaves its result in the 1x1 Stata matrix btwnvar
    mata: ustater()  
    local btwnvar_wt = btwnvar[1,1]
    return scalar total_btwnvar_wt = `btwnvar_wt'
    return scalar total_btwnsd_wt = sqrt(`btwnvar_wt')

    * Repeat with every job weighted equally
    qui replace wt = 1
    mkmat wt if tg == 1, matrix(fweights)
    mata: ustater()  
    local btwnvar = btwnvar[1,1]
    return scalar total_btwnvar = `btwnvar'
    return scalar total_btwnsd = sqrt(`btwnvar')

    restore
end

capture program drop estfirmjobs
program define estfirmjobs, rclass
    * Syntax: estfirmjobs attribute
    *   attribute  variable whose values 0 and 1 define the two groups
    *
    * Uses the variables jid, w and cb from the data in memory (preserved on
    * entry, restored on exit) and returns in r():
    *   jmean     weighted mean over jobs of (mean cb | attribute == 1)
    *             minus (mean cb | attribute == 0)
    *   jmean_se  its robust standard error
    args attribute

    preserve

    * Every row of a job carries that job's first weight
    tempvar jobw
    bys jid: egen `jobw' = first(w)
    qui replace w = `jobw'

    * Job-level mean of cb for each attribute value, placed side by side
    collapse (mean) cb (firstnm) w, by(jid `attribute')
    qui reshape wide cb, i(jid w) j(`attribute')

    * Within-job gap; jobs lacking either attribute value drop out
    qui gen gap = cb1 - cb0
    qui drop if missing(gap)

    * The constant of an intercept-only weighted regression is the weighted
    * mean of the gap
    qui reg gap [iweight=w], vce(robust)
    return scalar jmean = _b[_cons]
    return scalar jmean_se = _se[_cons]

    restore
end

* bootprev: point estimates and weighted-bootstrap standard errors
*
* Syntax: bootprev nboots attribute [split]
*   nboots     number of bootstrap replications
*   attribute  0/1 variable passed on to withinustat and estfirmjobs
*   split      optional split cutoff passed on to withinustat (withinustat
*              applies its own default when this is empty)
*
* Overwrites the variable w in the data in memory: first with 1, then with a
* fresh random draw on every replication.
*
* Returns in r():
*   total_btwnvar, total_btwnsd, total_btwnvar_wt, total_btwnsd_wt,
*   jmean, jmean_se, prev            estimates computed with w = 1
*   se_total_btwnvar, se_total_btwnsd, se_total_btwnvar_wt,
*   se_total_btwnsd_wt, se_jmean, se_prev
*                                    standard deviations across replications
*
* -program drop- has no documented options, so the stray ", rclass" is
* removed; -capture- would have hidden any error it raised and left a
* previously defined bootprev in place.
capture program drop bootprev
program define bootprev, rclass
    args nboots attribute split

    * Point estimates use unit weights
    qui replace w = 1
    withinustat `attribute' `split'
    return scalar total_btwnvar = r(total_btwnvar)
    local total_btwnvar = r(total_btwnvar)
    return scalar total_btwnsd = r(total_btwnsd)
    return scalar total_btwnvar_wt = r(total_btwnvar_wt)
    return scalar total_btwnsd_wt = r(total_btwnsd_wt)

    * Estimate means
    estfirmjobs `attribute'
    return scalar jmean = r(jmean)
    return scalar jmean_se = r(jmean_se)
    local jmean = r(jmean)
    local jmean_se = r(jmean_se)

    * Estimate prevalence: squared mean net of its squared standard error,
    * as a share of itself plus the variance estimate
    local mean_sq = (`jmean')^2 - (`jmean_se')^2
    return scalar prev = `mean_sq'/(`mean_sq' + `total_btwnvar')

    * One row per replication. Columns: 1 total_btwnvar, 3 total_btwnvar_wt,
    * 5 jmean, 6 prev. Columns 2 and 4 hold the constant 1, so the returned
    * se_total_btwnsd and se_total_btwnsd_wt are always 0.
    * NOTE(review): presumably the sd columns are placeholders because the
    * square root is missing when a replication's variance is negative -
    * confirm before relying on those two standard errors.
    matrix bsres = J(`nboots',6,.)
    foreach s of numlist 1/`nboots' {
        di "Working on bootstrap loop `s'"
        * Random weights for this replication
        qui replace w = -ln(uniform())

        withinustat `attribute' `split'
        matrix bsres[`s',1] = r(total_btwnvar)
        local total_btwnvar = r(total_btwnvar)
        matrix bsres[`s',2] = 1
        matrix bsres[`s',3] = r(total_btwnvar_wt)
        matrix bsres[`s',4] = 1

        * Estimate means
        estfirmjobs `attribute'
        matrix bsres[`s',5] = r(jmean)
        local jmean = r(jmean)
        local jmean_se = r(jmean_se)

        * Estimate prevalence
        local mean_sq = (`jmean')^2 - (`jmean_se')^2
        matrix bsres[`s',6] = `mean_sq'/(`mean_sq' + `total_btwnvar')
    }
    mat li bsres

    * Return the standard deviations across replications: square roots of
    * the diagonal of the variance matrix of the columns of bsres
    mata: bsvar = quadvariance(st_matrix("bsres")[.,1..6])
    mata: st_matrix("bsvar", bsvar)
    return scalar se_total_btwnvar = sqrt(bsvar[1,1])
    return scalar se_total_btwnsd = sqrt(bsvar[2,2])
    return scalar se_total_btwnvar_wt = sqrt(bsvar[3,3])
    return scalar se_total_btwnsd_wt = sqrt(bsvar[4,4])
    return scalar se_jmean = sqrt(bsvar[5,5])
    return scalar se_prev = sqrt(bsvar[6,6])
end

* Load data
* File paths are quoted throughout so that a ${dir} or temporary directory
* containing spaces does not break the commands.
use "${dir}/data/data.dta", clear

* Constructed variables
* (mean_cb is not referenced again in the visible code)
bys firm_id: egen mean_cb = mean(cb)
gen w = 1
egen jid = group(firm_id job_id)

* Results returned by bootprev, listed in the column order of the output
local stats total_btwnvar total_btwnsd total_btwnvar_wt total_btwnsd_wt ///
    jmean jmean_se prev ///
    se_total_btwnvar se_total_btwnsd se_total_btwnvar_wt se_total_btwnsd_wt ///
    se_jmean se_prev

* Loop over attributes
foreach attribute of varlist black female over40  {
    * bootprev overwrites w and the results row replaces the data in memory,
    * so keep a copy of the analysis data to restore at the end of the pass
    preserve

    * Estimate variance, then copy each r() result into a local of the same
    * name before -clear- and the commands after it can disturb r()
    bootprev ${bsreps} `attribute'
    foreach s of local stats {
        local `s' = r(`s')
    }

    * Save output: a one-row dataset holding this attribute's results
    clear
    set obs 1
    gen attribute = "`attribute'"
    foreach s of local stats {
        gen `s' = ``s''
    }

    * Stack the rows collected so far underneath the new row
    append using "`allresults'"
    save "`allresults'", replace
    list if _n == 1
    restore
}
use "`allresults'", clear
outsheet using "${dir}/dump/tableA4.csv", comma replace


